# Copyright 2020 Fabio Tosi, Filippo Aleotti, Pierluigi Zama Ramirez, Matteo Poggi,
# Samuele Salti, Luigi Di Stefano, Stefano Mattoccia
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
General network, superclass for other networks
"""

from abc import ABCMeta, abstractmethod
import tensorflow as tf
from helpers.utilities import get_num_classes, get_height_width, colormap_semantic
from collections import namedtuple

# Immutable bundle of network configuration values; fields are listed
# explicitly, in positional order.
network_parameters = namedtuple(
    "network_parameters",
    ["height", "width", "load_only_baseline", "tau"],
)


class GeneralNetwork(object):
    """Template that the concrete networks inherit from.

    Subclasses provide ``build_network`` and ``build_outputs``; ``build``
    invokes them in that order. ``build_masks`` is an optional hook whose
    base implementation does nothing.

    NOTE(review): the ``__metaclass__`` attribute is only honoured by
    Python 2. Under Python 3 it is a plain class attribute, so the
    ``@abstractmethod`` markers are not enforced and an un-overridden method
    silently returns None. Switching to ``metaclass=ABCMeta`` would enforce
    them, but could break subclasses that do not override both methods --
    confirm before changing.
    """

    __metaclass__ = ABCMeta

    def __init__(self, batch, is_training, params):
        """Store the inputs and derive metadata used when building the graph.

        Args:
            batch: mapping that must contain the keys "src_img_1", "tgt_img"
                and "src_img_2" (presumably image tensors -- TODO confirm).
            is_training: training-mode flag, stored unchanged.
            params: network configuration, stored unchanged (presumably a
                ``network_parameters`` tuple -- verify against callers).
        """
        self.is_training = is_training
        self.classes = get_num_classes()
        self.params = params
        # Pull the two source views and the target view out of the batch.
        self.src_img_1, self.tgt_img, self.src_img_2 = (
            batch["src_img_1"],
            batch["tgt_img"],
            batch["src_img_2"],
        )
        # Spatial size is derived from the target image.
        self.h, self.w = get_height_width(self.tgt_img)
        self.batch = batch

    def build(self):
        """Create the model: network body first, then its outputs."""
        # Resolve each stage lazily, right before calling it.
        for stage_name in ("build_network", "build_outputs"):
            getattr(self, stage_name)()

    @abstractmethod
    def build_network(self):
        """Define the network architecture; must be supplied by subclasses."""

    @abstractmethod
    def build_outputs(self):
        """Define the outputs produced by the network; supplied by subclasses."""

    def build_masks(self):
        """Optional hook for stage-specific masks; the base version is a no-op."""
